
# FUNCTIONS for RANDOMISATION INFERENCE ---------------------------------------------------------------


library(randomizr); library(ri2)

# Re-fit a model on a (possibly permuted) dataset and return one coefficient.
#
# Arguments:
#   df              Data to fit the model on.
#   original_model  Fitted model whose `fml`, `cluster` and `fixef_vars`
#                   elements are re-used for the new fit.
#   coeff           Name of the coefficient to extract.
#   extra_df        Optional rows appended to `df` (via bind_rows) before
#                   fitting.
#
# Returns the element `coeff` of the re-fitted model's coefficients.
output_coeff <- function(df, original_model, coeff, extra_df = NULL) {
  if (!is.null(extra_df)) {
    df <- suppressMessages(bind_rows(df, extra_df))
  }

  refit <- suppressMessages(
    feols_custom(
      original_model$fml,
      data = df,
      cluster = original_model$cluster,
      fixef = original_model$fixef_vars
    )
  )

  coefficients(refit)[[coeff]]
}





# Returns TRUE if at least one element of `myList` is NULL, FALSE otherwise.
check_null <- function(myList) {
  null_flags <- vapply(myList, is.null, logical(1))
  any(null_flags)
}


# Build a randomizr declaration (declare_ra) that mirrors the assignment
# structure observed in an existing dataset.
#
# Arguments:
#   df       Data frame holding the realised assignment.
#   treat    Name (string) of the treatment column.
#   stratum  Name of the blocking column (between-subject designs).
#   cluster  Name of the cluster column (between-subject designs). Rows are
#            collapsed to distinct treat/cluster/stratum combinations before
#            counting, so block_m_each counts clusters rather than rows.
#   within   Name of the unit id column for within-subject designs. Each unit
#            becomes its own block and the number of rows with each treatment
#            value per unit is reproduced (e.g. for pair_includes_trans: 2 rows
#            TRUE and 4 rows FALSE for every individual).
#   across   If TRUE, ignore blocks/clusters and randomise over all rows while
#            keeping the observed number of rows per treatment value.
#
# Returns the object created by declare_ra().
#
# Errors if `within` is combined with `cluster`/`stratum`, or if neither a
# stratum nor `within` is supplied while across = FALSE.
get_ra_declaration <- function(df, treat, stratum = NULL, cluster = NULL, within = NULL, across = FALSE) {

  if (!is.null(within) && (!is.null(cluster) || !is.null(stratum))) {
    stop("Cannot specify both cluster/stratum and within")
  }

  # A within-subject design is treated as blocking on the unit id
  if (!is.null(within)) stratum <- within

  if (across) {
    t_count <- df %>% count(!!sym(treat))
    return(declare_ra(
      N = nrow(df),
      conditions = as.character(t_count[[treat]]),
      m_each = t_count$n
    ))
  }

  # Without blocks there is nothing to build block_m_each from; fail clearly
  # instead of falling through to an undefined result.
  if (is.null(stratum)) {
    stop("get_ra_declaration() needs a stratum (or `within`) unless across = TRUE")
  }

  block_m_each_prep <- df %>%
    select(any_of(c(treat, cluster, stratum, within)))

  # Collapse to one row per cluster only when assignment is clustered.
  # Without clusters the unit of assignment is the row, so rows must be
  # counted as they are.
  if (!is.null(cluster)) {
    block_m_each_prep <- block_m_each_prep %>% distinct()
  }

  # Rows of block_m_each follow the sorted block values (the ordering the
  # original code notes declare_ra uses). Columns follow the sorted treatment
  # values, so that for a 0/1 treatment column 1 is the count of 0s and
  # column 2 the count of 1s, independent of row order in `df`.
  # NOTE(review): assumes declare_ra's default condition order for two arms is
  # (0, 1) - confirm against the installed randomizr version.
  stratum_levels <- sort(unique(block_m_each_prep[[stratum]]))
  treat_levels <- sort(unique(block_m_each_prep[[treat]]))

  counts <- table(
    factor(block_m_each_prep[[stratum]], levels = stratum_levels),
    factor(block_m_each_prep[[treat]], levels = treat_levels)
  )
  block_m_each <- matrix(
    as.numeric(counts),
    nrow = length(stratum_levels),
    ncol = length(treat_levels)
  )

  if (!is.null(cluster)) {
    # Validation only (result discarded): this errors when block_m_each is
    # inconsistent with the blocks/clusters, which typically means a treatment
    # that varies within a cluster was declared as "between".
    block_and_cluster_ra_probabilities(
      blocks = df[[stratum]],
      clusters = df[[cluster]],
      block_m_each = block_m_each
    )

    return(declare_ra(
      N = nrow(df),
      blocks = df[[stratum]],
      clusters = df[[cluster]],
      block_m_each = block_m_each
    ))
  }

  # Blocked design without clusters (within-subject, or stratum only)
  declare_ra(
    N = nrow(df),
    blocks = df[[stratum]],
    block_m_each = block_m_each
  )
}


# rand_setup

# Custom counterpart of ri2's internal conduct_ri_test_function that can
# permute several assignment variables at once.
#
# Arguments:
#   test_function  Function taking a data frame and returning a single numeric
#                  test statistic.
#   assignments    Character vector of assignment column names in `data`.
#   declarations   List of randomizr declarations, one per entry of
#                  `assignments` (same order).
#   data           Data frame the statistic is computed on.
#   sims           Maximum number of permutations to draw.
#   progress_bar   If TRUE, show a text progress bar while simulating.
#   outcome, declaration, sharp_hypothesis, IPW_weights, sampling_weights,
#   permutation_matrix
#                  Accepted so the signature matches the caller; not used in
#                  this function body.
#
# Returns an object of class "ri": a list whose `sims_df` element has columns
# est_sim (simulated statistics), est_obs (observed statistic) and term.
conduct_ri_test_function_custom <- function(test_function, assignments, declarations, outcome = "Y", declaration, sharp_hypothesis = 0, IPW_weights = NULL, sampling_weights = NULL, permutation_matrix = NULL, data, sims = 1000, progress_bar = FALSE) {

  if (length(assignments) != length(declarations)) {
    stop("`assignments` and `declarations` must have the same length")
  }

  # Observed statistic on the data as supplied
  test_stat_obs <- suppressMessages(suppressWarnings(test_function(data)))

  # One permutation matrix per assignment variable (rows = units, cols = draws)
  permutation_matrices <- lapply(
    declarations,
    function(decl) obtain_permutation_matrix(decl, maximum_permutations = sims)
  )

  # Match the storage type of each permutation matrix to its data column so
  # that re-assignment does not change the column type.
  for (v in seq_along(assignments)) {
    assignment_type <- typeof(data[[assignments[[v]]]])
    perm <- permutation_matrices[[v]]
    if (assignment_type == "character") {
      permutation_matrices[[v]] <- matrix(as.character(perm), nrow = nrow(perm))
    } else if (assignment_type == "double") {
      permutation_matrices[[v]] <- matrix(as.numeric(perm), nrow = nrow(perm))
    }
  }

  # A matrix can hold fewer than `sims` columns when fewer distinct
  # permutations exist; never index past the narrowest matrix.
  n_available <- vapply(permutation_matrices, ncol, integer(1))
  n_draws <- min(sims, n_available)
  if (n_draws < sims) {
    warning(
      sprintf("Only %d permutations available (sims = %d); using %d", n_draws, sims, n_draws),
      call. = FALSE
    )
  }

  test_stat_sim <- numeric(n_draws)

  show_bar <- isTRUE(progress_bar) && n_draws > 0
  if (show_bar) {
    pb <- txtProgressBar(min = 0, max = n_draws, style = 3)
    on.exit(close(pb), add = TRUE)
  }

  # Special case: the difference is coded 0 whenever reliability is not shown,
  # so that coding is re-imposed after both variables have been permuted.
  is_reliability_pair <- setequal(assignments, c("r2_reliability_diff", "r2_reliability_shown"))

  for (i in seq_len(n_draws)) {
    if (show_bar) setTxtProgressBar(pb, i)

    # Reassign every assignment variable to its i-th permutation
    for (v in seq_along(assignments)) {
      data[[assignments[[v]]]] <- permutation_matrices[[v]][, i]
    }

    if (is_reliability_pair) {
      data[["r2_reliability_diff"]] <- ifelse(
        !data[["r2_reliability_shown"]],
        0,
        data[["r2_reliability_diff"]]
      )
    }

    test_stat_sim[[i]] <- test_function(data)
  }

  sims_df <- data.frame(est_sim = test_stat_sim, est_obs = test_stat_obs, term = "Custom Test Statistic")
  structure(list(sims_df = sims_df), class = "ri")
}


# Custom version of conduct_ri() that can permute more than one assignment
# variable at once while returning output in the same format.
#
# Only the `test_function` route is implemented here: `formula`, `model_1` and
# `model_2` are accepted in the signature but there is no code path for them,
# so supplying them without `test_function` raises an error.
#
# Arguments:
#   test_function  Function taking a data frame and returning the statistic.
#   assignments    Character vector of assignment column names.
#   declarations   List of randomizr declarations, one per assignment.
#   data, sims, progress_bar, outcome, sharp_hypothesis, IPW_weights,
#   sampling_weights, permutation_matrix
#                  Passed through to conduct_ri_test_function_custom().
#   p              Stored on the result as `p`.
#   studentize, IPW
#                  Not used in this function body.
#
# Returns the "ri" object from conduct_ri_test_function_custom() with est_sim
# and est_obs rounded to 10 decimals and `p` / `sharp_hypothesis` attached.
conduct_ri_custom <- function(formula = NULL, model_1 = NULL, model_2 = NULL, test_function = NULL, assignments, outcome = NULL, declarations = NULL, sharp_hypothesis = 0, studentize = FALSE, IPW = TRUE, IPW_weights = NULL, sampling_weights = NULL, permutation_matrix = NULL, data, sims = 1000, progress_bar = FALSE, p = "two-tailed") {
  # Validate before doing any work
  if (is.null(formula) && is.null(model_1) && is.null(model_2) && is.null(test_function)) {
    stop("You must specify either a formula, models 1 and 2, or a test function.")
  }
  if (is.null(test_function)) {
    stop("conduct_ri_custom() only supports `test_function`; formula / model_1 / model_2 are not implemented.")
  }

  ri_out <- conduct_ri_test_function_custom(
    test_function = test_function,
    assignments = assignments,
    outcome = outcome,
    declarations = declarations,
    sharp_hypothesis = sharp_hypothesis,
    IPW_weights = IPW_weights,
    sampling_weights = sampling_weights,
    permutation_matrix = permutation_matrix,
    data = data,
    sims = sims,
    progress_bar = progress_bar
  )

  # Round both estimate columns to 10 decimals
  sims_df <- ri_out$sims_df
  sims_df$est_sim <- round(sims_df$est_sim, 10)
  sims_df$est_obs <- round(sims_df$est_obs, 10)
  ri_out$sims_df <- sims_df

  ri_out$p <- p
  ri_out$sharp_hypothesis <- sharp_hypothesis
  ri_out
}


# Runs randomisation inference for the coefficients of a fitted model, choosing
# the permutation scheme from the declared type of each variable.
#
# Variable types (values of `var_types`):
#   "between"    Randomised between subjects: everyone in the same cluster has
#                the same value; `stratum` defines the blocks.
#   "within"     Randomised within subjects: blocked on ind_id, reproducing the
#                per-individual counts (e.g. always 2 trans photos each).
#   "across"     Randomised across all observations (e.g. item_diff).
#   "ind"        Individual-level role (e.g. listener vs speaker); requires
#                `control_group`, and `same_group_spec` to decide whether to
#                swap within groups or between groups.
#   "non_random" Not permuted; the RI p-value for a term made up only of such
#                variables is returned as NA.
# Terms matching "reliability_diff|reliability_shown" are a hard-coded special
# case in which both r2_reliability_diff and r2_reliability_shown are permuted.
#
# NOTE: this setup requires all the variables to be NUMERIC, both in the model
# and in `df`, otherwise the syntax will break.
#
# Arguments:
#   df               Data the model was fitted on; must contain ind_id and
#                    group_id.
#   model            Fitted model; its fml, cluster and fixef_vars are re-used.
#   n_sims           Number of permutations per term.
#   stratum, cluster Column names used for "between" variables.
#   coef_keep        Optional regex; only matching terms are analysed.
#   coef_omit        Optional regex; matching terms are dropped.
#   var_types        Named list mapping variable name -> type (see above).
#   control_group    Optional named list: names are control-group columns,
#                    values the treatment columns compared against them.
#   same_group_spec  Optional list of character vectors naming variables that
#                    belong to the same group (needed for "ind").
#   ...              Not used in this function body.
#
# Returns a tibble with one row per term: term, est_obs, est_sim (list column
# of simulated estimates), ri_obj (list column of "ri" objects) and p.
ri_custom_v2 <- function(df, model, n_sims, stratum, cluster,
                         coef_keep = NULL, coef_omit = NULL, var_types,
                         control_group = NULL,
                         same_group_spec = NULL,
                         ...) {

  # Get the coefficients to calculate for:
  terms <- model %>% tidy() %>% pull(term)
  if (!is.null(coef_keep)) terms <- terms %>% str_subset(coef_keep)
  if (!is.null(coef_omit)) terms <- terms %>% str_subset(coef_omit, negate = TRUE)
  print(paste0("Calculating RI for the following terms: ", paste(terms, collapse = ", ")))
  if (length(terms) == 0) stop("No coefficients to calculate RI for")

  n_terms <- length(terms)
  out_sims <- vector("list", n_terms)
  out_obs <- vector("list", n_terms)
  out_p <- vector("list", n_terms)
  out_ri <- vector("list", n_terms)

  # STREAMLINE DATASET to speed things up a lot (remove extraneous variables)
  vars_fml <- model$fml %>% fml_to_vars()
  vars_cluster <- model$cluster
  vars_fixef <- model$fixef_vars
  vars_control_group <- if (!is.null(control_group)) {
    c(names(control_group), flatten_chr(control_group))
  } else {
    NULL
  }
  vars_model <- c("ind_id", "group_id", stratum, vars_fml, vars_cluster, vars_fixef, vars_control_group)
  df <- df %>% select(all_of(vars_model))

  # For each of the terms we want to calculate RI p-value for
  for (i in seq_along(terms)) {

    treat <- terms[[i]]

    # 1. SPECIAL CASE - reliability diff and reliability shown
    if (str_detect(treat, "reliability_diff|reliability_shown")) {

      # Statistic for one simulated dataset: the coefficient of the current term
      output_coeff_specified <- function(df) { output_coeff(df, original_model = model, coeff = terms[[i]]) }

      print(str_glue("Running RI for {treat}, (var_type = special case)"))

      # Declaration for reliability diff: an "across" treatment whose
      # probabilities come from the rows where reliability was shown (rows
      # where it is not shown are coded 0 and excluded here).
      p_treatments <- df %>%
        filter(r2_reliability_shown == 1) %>%
        count_prop(r2_reliability_diff, return_count = TRUE)

      rand_reliability_diff <- declare_ra(
        N = nrow(df),
        conditions = p_treatments %>% pull(r2_reliability_diff),
        prob_each = p_treatments %>% pull(prop)
      )

      # Reliability shown is a within treatment
      rand_reliability_shown <- get_ra_declaration(df, "r2_reliability_shown", within = "ind_id")

      rand_setup_list <- list(rand_reliability_diff, rand_reliability_shown)
      treat_list <- c("r2_reliability_diff", "r2_reliability_shown")

      ri_out <- conduct_ri_custom(
        test_function = output_coeff_specified,
        declarations = rand_setup_list,
        p = "two-tailed",
        data = df,
        sims = n_sims,
        progress_bar = TRUE,
        assignments = treat_list
      )

    } else {

      # Work out which type of variable it is
      if (str_detect(terms[[i]], "\\:")) {
        # INTERACTION TERMS: split the variables and get the type for each one
        var_split <- treat %>% str_split("\\:") %>% unlist()
        var_type_list <- var_split %>% map(~ var_types[[.x]])
      } else {
        # NORMAL TERMS
        var_split <- terms[[i]]
        var_type_list <- var_types[terms[[i]]]
      }
      if (check_null(var_type_list)) stop(str_glue("Issue with {treat} not found in var_types"))

      # Exclude any non-random variables; they are never permuted
      is_random <- flatten_chr(var_type_list) != "non_random"
      treat_list <- var_split[is_random]
      var_type_list <- var_type_list[is_random]

      # Control-group handling depends only on the whole treat_list, so it is
      # resolved once per term. Doing it here also guarantees that
      # affected_df / unaffected_df are defined when treat_list is empty.
      has_control <- !is.null(control_group) &&
        any(treat_list %in% flatten_chr(control_group))

      if (has_control) {
        control_group_i <- names(control_group[map_lgl(control_group, ~ any(treat_list %in% .x))])
        treat_with_control <- treat_list[treat_list %in% flatten_chr(control_group)]
        if (length(control_group_i) > 1) stop(str_glue("Multiple control groups found for {treat}"))

        # Rows in neither treatment nor control are left untouched and are
        # appended back before each model fit; only the rest is permuted.
        unaffected_df <- df %>% filter(!!sym(treat_with_control) != 1 & !!sym(control_group_i) != 1)
        affected_df <- df %>% filter(!!sym(treat_with_control) == 1 | !!sym(control_group_i) == 1)

        # What is the var_type of the control group?
        var_type_control_group <- var_types[[control_group_i]]
      } else {
        # No control group for this term: use the whole df
        control_group_i <- NULL
        var_type_control_group <- NULL
        unaffected_df <- NULL
        affected_df <- df
      }

      # Statistic for one simulated dataset. `i` and `unaffected_df` are looked
      # up when the function is called, i.e. they refer to the current term.
      output_coeff_specified <- function(df) { output_coeff(df, original_model = model, coeff = terms[[i]], extra_df = unaffected_df) }

      rand_setup_list <- list()

      for (v in seq_along(treat_list)) {

        treat <- treat_list[[v]]
        var_type <- var_type_list[[v]]

        if (is.null(var_type) || var_type == "NULL") stop(str_glue("Variable {terms[[i]]} not found in var_types"))

        # str_glue() returns an empty result when an interpolated value is
        # NULL, so substitute a label when there is no control group.
        control_label <- if (is.null(control_group_i)) "none" else control_group_i
        print(str_glue("Running RI for {treat}, (var_type = {var_type}, control group = {control_label})"))

        if (var_type == "between") {
          # Between-subjects: default cluster and stratum (e.g. discussion or video)
          rand_setup <- get_ra_declaration(affected_df, treat, stratum, cluster)

        } else if (var_type == "within") {
          # Within-subjects: blocked on ind_id, i.e. assumes randomisation
          # occurred within individuals (e.g. always 2 trans photos per person)
          rand_setup <- get_ra_declaration(affected_df, treat, within = "ind_id")

        } else if (var_type == "across") {
          # Randomised across all observations rather than within an individual
          rand_setup <- get_ra_declaration(affected_df, treat, across = TRUE)

        } else if (var_type == "ind") {
          # e.g. listener vs speaker
          if (is.null(control_group_i)) stop(str_glue("Variable {treat} is of type {var_type}, so you need to specify the control group"))
          if (is.null(same_group_spec)) stop(str_glue("Variable {treat} is of type {var_type}, so you need to specify which are in the same group"))
          if (is.null(var_type_control_group)) stop(str_glue("Control group {control_group_i} not found in var_types"))

          # Are treat and its control group listed in the same group?
          same_group <- same_group_spec %>%
            map_lgl(~ sum(c(treat, control_group_i) %in% .x) == 2) %>%
            any()

          print(str_glue("Same group = {same_group}"))

          if (var_type_control_group == "between") {
            # Control group is between: swap between groups (same as between)
            rand_setup <- get_ra_declaration(affected_df, treat, stratum, cluster)
          } else if (var_type_control_group == "ind") {
            if (same_group) {
              # Same group: swap within groups but between individuals
              print("Same group")
              rand_setup <- get_ra_declaration(affected_df, treat, stratum = "group_id", cluster = "ind_id")
            } else {
              # Different group: swap between groups
              rand_setup <- get_ra_declaration(affected_df, treat, stratum, cluster)
            }
          } else {
            # Fail rather than silently re-using a declaration from an
            # earlier variable.
            stop(str_glue("Unsupported control group type '{var_type_control_group}' for {treat}"))
          }

        } else {
          # Fail rather than silently re-using a declaration from an earlier
          # variable.
          stop(str_glue("Unknown var_type '{var_type}' for {treat}"))
        }

        # Append the declaration as ONE list element (c() would splice a
        # list-like declaration into its components).
        rand_setup_list[[length(rand_setup_list) + 1]] <- rand_setup

      }

      print(str_glue("treat_list: {paste0(treat_list, collapse = ', ')}"))

      ri_out <- conduct_ri_custom(
        test_function = output_coeff_specified,
        declarations = rand_setup_list,
        p = "two-tailed",
        data = affected_df,
        sims = n_sims,
        progress_bar = TRUE,
        assignments = treat_list
      )
    }

    # OUTPUT TO THE LISTS
    out_sims[[i]] <- ri_out$sims_df$est_sim
    out_obs[[i]] <- ri_out$sims_df$est_obs[[1]]
    out_ri[[i]] <- ri_out
    # For non-random variables, set the p-value to NA so that the original
    # p-value is picked up instead
    if (length(treat_list) == 0) {
      out_p[[i]] <- NA
    } else {
      out_p[[i]] <- ri2:::summary.ri(ri_out)$two_tailed_p_value
    }

  }

  # One row per term
  tibble(
    term = terms,
    est_obs = out_obs %>% unlist(),
    est_sim = out_sims,
    ri_obj = out_ri,
    p = out_p %>% unlist()
  )
}

# Manual smoke test. Disabled by default; set `test` to TRUE to run it
# interactively. Relies on `r2_choices` and the project helpers being loaded.
test <- FALSE

if (test) {

  # Baseline model with every treatment variable recoded to numeric
  model_test <- feols_custom(
    r2_choose_comparator ~ item_diff + pair_includes_trans * (video_type_placebo + video_type_treatment + group_label + stratum_id),
    data = mutate(
      r2_choices,
      group_label = as.numeric(group_label == "discuss"),
      r2_choose_comparator = as.numeric(r2_choose_comparator),
      pair_includes_trans = as.numeric(pair_includes_trans),
      video_type_placebo = as.numeric(video_type == "placebo"),
      video_type_treatment = as.numeric(video_type == "treatment")
    ),
    fixef = c("stratum_id", "video_type", "delivery_incentive_exp"),
    cluster = "group_id"
  )

  # Quick inspection of the data and the fitted model
  count_prop(r2_choices, item_diff)
  model_test$collin.var
  tidy(model_test)

  # RI for item_diff only
  # (alternative coef_keep: "(item_diff|pair_includes_trans($|\\:video_type|group_label))")
  test_ri_custom_v2 <- ri_custom_v2(
    df = mutate(
      r2_choices,
      group_label = as.numeric(group_label == "discuss"),
      r2_choose_comparator = as.numeric(r2_choose_comparator),
      pair_includes_trans = as.numeric(pair_includes_trans),
      video_type_placebo = as.numeric(video_type == "placebo"),
      video_type_treatment = as.numeric(video_type == "treatment")
    ),
    model = model_test,
    n_sims = 100,
    stratum = "stratum_id",
    cluster = "group_id",
    coef_keep = "item_diff",
    var_types = list(
      "pair_includes_trans" = "within",
      "group_label" = "between",
      "video_type_placebo" = "between",
      "video_type_treatment" = "between",
      "item_diff" = "across"
    )
  )

  # Same model plus the reliability variables, with RI requested through
  # feols_custom() itself
  ri_test <- feols_custom(
    r2_choose_comparator ~ item_diff + r2_reliability_diff +  r2_reliability_shown +  pair_includes_trans * (video_type_placebo + video_type_treatment + group_label + stratum_id),
    data = mutate(
      r2_choices,
      group_label = as.numeric(group_label == "discuss"),
      r2_choose_comparator = as.numeric(r2_choose_comparator),
      pair_includes_trans = as.numeric(pair_includes_trans),
      video_type_placebo = as.numeric(video_type == "placebo"),
      video_type_treatment = as.numeric(video_type == "treatment"),
      r2_reliability_shown = as.numeric(r2_reliability_shown)
    ),
    fixef = c("stratum_id", "video_type", "delivery_incentive_exp"),
    cluster = "group_id",
    stratum = "stratum_id",
    var_types = list(
      "pair_includes_trans" = "within",
      "group_label" = "between",
      "video_type_placebo" = "between",
      "video_type_treatment" = "between",
      "item_diff" = "across"
    ),
    coef_keep = "r2_reliability_diff|r2_reliability_shown",
    n_sims = 100,
    ri = TRUE
  )

  tidy_custom.fixest(ri_test)

  ri_test$ri_out

}









